/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

#include <cstdint>
#include <cstddef>

#include "cuda/mul_wide.cuh"

#define FE_SIZE 8

#define REG_SPEC_4(S, R1, R2, R3, R4) S(R1), S(R2), S(R3), S(R4)
#define REG_SPEC_8(S, R1, R2, R3, R4, R5, R6, R7, R8) REG_SPEC_4(S, R1, R2, R3, R4), REG_SPEC_4(S, R5, R6, R7, R8)

using FieldElement = uint32_t[8];

#define fe_copy(DST, SRC)                      \
    do                                         \
    {                                          \
        for (unsigned i = 0; i < FE_SIZE; ++i) \
            (DST)[i] = (SRC)[i];               \
    } while (0)

__host__ __device__ void fe_one(FieldElement x)
{
    x[0] = 1;
    for(auto i = 1; i < 8; ++i)
        x[i] = 0;
}

__host__ __device__ void fe_zero(FieldElement x)
{
    for(auto i = 0; i < 8; ++i)
        x[i] = 0;
}

__device__ inline void fe_reduce(FieldElement z)
{
    asm(R"({
        .reg.u32 t0, t1;
        sub.cc.u32      %0, %8, 0xffffffed;
        subc.cc.u32     %1, %9, 0xffffffff;
        subc.cc.u32     %2, %10, 0xffffffff;
        subc.cc.u32     %3, %11, 0xffffffff;
        subc.cc.u32     %4, %12, 0xffffffff;
        subc.cc.u32     %5, %13, 0xffffffff;
        subc.cc.u32     %6, %14, 0xffffffff;
        subc.cc.u32     %7, %15, 0x7fffffff;
        subc.u32        t0,  0, 0;
        and.b32         t1,  t0, 0xffffffed;
        add.cc.u32      %0,  %0, t1;
        addc.cc.u32     %1,  %1, t0;
        addc.cc.u32     %2,  %2, t0;
        addc.cc.u32     %3,  %3, t0;
        addc.cc.u32     %4,  %4, t0;
        addc.cc.u32     %5,  %5, t0;
        addc.cc.u32     %6,  %6, t0;
        shr.b32         t0,  t0, 1;
        addc.u32        %7,  %7, t0;

        sub.cc.u32      %0, %8, 0xffffffed;
        subc.cc.u32     %1, %9, 0xffffffff;
        subc.cc.u32     %2, %10, 0xffffffff;
        subc.cc.u32     %3, %11, 0xffffffff;
        subc.cc.u32     %4, %12, 0xffffffff;
        subc.cc.u32     %5, %13, 0xffffffff;
        subc.cc.u32     %6, %14, 0xffffffff;
        subc.cc.u32     %7, %15, 0x7fffffff;
        subc.u32        t0,  0, 0;
        and.b32         t1,  t0, 0xffffffed;
        add.cc.u32      %0,  %0, t1;
        addc.cc.u32     %1,  %1, t0;
        addc.cc.u32     %2,  %2, t0;
        addc.cc.u32     %3,  %3, t0;
        addc.cc.u32     %4,  %4, t0;
        addc.cc.u32     %5,  %5, t0;
        addc.cc.u32     %6,  %6, t0;
        shr.b32         t0,  t0, 1;
        addc.u32        %7,  %7, t0;
    })"
    // %0 - %7
    : REG_SPEC_8("=r", z[0], z[1], z[2], z[3], z[4], z[5], z[6], z[7])
    // %8 - %15
    : REG_SPEC_8("r", z[0], z[1], z[2], z[3], z[4], z[5], z[6], z[7])
    );
}

__device__ inline void fe_add(FieldElement z, const FieldElement x, const FieldElement y)
{
    asm(R"({
        .reg.u32       t;
        add.cc.u32     %0, %8,  %16;
        addc.cc.u32    %1, %9,  %17;
        addc.cc.u32    %2, %10, %18;
        addc.cc.u32    %3, %11, %19;
        addc.cc.u32    %4, %12, %20;
        addc.cc.u32    %5, %13, %21;
        addc.cc.u32    %6, %14, %22;
        addc.cc.u32    %7, %15, %23;
        addc.u32       t,  0,   0;
        mul.lo.u32     t,  t,   38;
        add.cc.u32     %0, %0,  t;
        addc.cc.u32    %1, %1,  0;
        addc.cc.u32    %2, %2,  0;
        addc.cc.u32    %3, %3,  0;
        addc.cc.u32    %4, %4,  0;
        addc.cc.u32    %5, %5,  0;
        addc.cc.u32    %6, %6,  0;
        addc.u32       %7, %7,  0;
    })"
    // %0 - %7
    : REG_SPEC_8("=r", z[0], z[1], z[2], z[3], z[4], z[5], z[6], z[7])
    // %8 - %15
    : REG_SPEC_8("r", x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]),
    // %16 - %23
      REG_SPEC_8("r", y[0], y[1], y[2], y[3], y[4], y[5], y[6], y[7])
    );
}

__device__ inline void fe_sub(FieldElement z, const FieldElement x, const FieldElement y)
{
    asm(R"({
        .reg.u32 t;
        sub.cc.u32     %0, %8,  %16;
        subc.cc.u32    %1, %9,  %17;
        subc.cc.u32    %2, %10, %18;
        subc.cc.u32    %3, %11, %19;
        subc.cc.u32    %4, %12, %20;
        subc.cc.u32    %5, %13, %21;
        subc.cc.u32    %6, %14, %22;
        subc.cc.u32    %7, %15, %23;
        subc.u32       t,  0,   0;
        and.b32        t,  t,   38;
        sub.cc.u32     %0, %0,  t;
        subc.cc.u32    %1, %1,  0;
        subc.cc.u32    %2, %2,  0;
        subc.cc.u32    %3, %3,  0;
        subc.cc.u32    %4, %4,  0;
        subc.cc.u32    %5, %5,  0;
        subc.cc.u32    %6, %6,  0;
        subc.u32       %7, %7,  0;
    })"
    // %0 - %7
    : REG_SPEC_8("=r", z[0], z[1], z[2], z[3], z[4], z[5], z[6], z[7])
    // %8 - %15
    : REG_SPEC_8("r", x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]),
    // %16 - %23
      REG_SPEC_8("r", y[0], y[1], y[2], y[3], y[4], y[5], y[6], y[7])
    );
}

__device__ inline void fe_mul(FieldElement z, const FieldElement x, const FieldElement y)
{
    uint32_t t[16];
    mul_wide(t, x, y);

    asm(R"({
        mad.lo.cc.u32   %0, 38, %8, %16;
        madc.hi.u32     %16, 38, %8, 0;
        mad.lo.cc.u32   %1, 38, %9, %16;
        madc.hi.u32     %16, 38, %9, 0;
        add.cc.u32      %1, %1, %17;
        madc.lo.cc.u32  %2, 38, %10, %16;
        madc.hi.u32     %16, 38, %10, 0;
        add.cc.u32      %2, %2, %18;
        madc.lo.cc.u32  %3, 38, %11, %16;
        madc.hi.u32     %16, 38, %11, 0;
        add.cc.u32      %3, %3, %19;
        madc.lo.cc.u32  %4, 38, %12, %16;
        madc.hi.u32     %16, 38, %12, 0;
        add.cc.u32      %4, %4, %20;
        madc.lo.cc.u32  %5, 38, %13, %16;
        madc.hi.u32     %16, 38, %13, 0;
        add.cc.u32      %5, %5, %21;
        madc.lo.cc.u32  %6, 38, %14, %16;
        madc.hi.u32     %16, 38, %14, 0;
        add.cc.u32      %6, %6, %22;
        madc.lo.cc.u32  %7, 38, %15, %16;
        madc.hi.u32     %16, 38, %15, 0;
        add.cc.u32      %7, %7, %23;
        addc.u32        %16, %16, 0;
        mul.lo.u32      %16, %16, 38;
        add.cc.u32      %0, %0, %16;
        addc.u32        %1, %1,  0;
        })"
        // %0 - %7
        : REG_SPEC_8("=r", z[0], z[1], z[2], z[3], z[4], z[5], z[6], z[7])
        // %8 - %15
        : REG_SPEC_8("r", t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]),
        // %16 - %23
          REG_SPEC_8("r", t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7])
        );
}

__device__ inline void fe_square(FieldElement z, const FieldElement x)
{
    fe_mul(z, x, x);
}

__device__ inline void fe_invert(FieldElement z, const FieldElement x)
{
	FieldElement t0;
	FieldElement t1;
	FieldElement t2;
	FieldElement t3;
	int     i;

	fe_square(t0, x);
	fe_square(t1, t0);
	fe_square(t1, t1);
	fe_mul(t1, x, t1);
	fe_mul(t0, t0, t1);
	fe_square(t2, t0);
	fe_mul(t1, t1, t2);
	fe_square(t2, t1);
	for (i = 1; i < 5; ++i) {
		fe_square(t2, t2);
	}
	fe_mul(t1, t2, t1);
	fe_square(t2, t1);
	for (i = 1; i < 10; ++i) {
		fe_square(t2, t2);
	}
	fe_mul(t2, t2, t1);
	fe_square(t3, t2);
	for (i = 1; i < 20; ++i) {
		fe_square(t3, t3);
	}
	fe_mul(t2, t3, t2);
	for (i = 1; i < 11; ++i) {
		fe_square(t2, t2);
	}
	fe_mul(t1, t2, t1);
	fe_square(t2, t1);
	for (i = 1; i < 50; ++i) {
		fe_square(t2, t2);
	}
	fe_mul(t2, t2, t1);
	fe_square(t3, t2);
	for (i = 1; i < 100; ++i) {
		fe_square(t3, t3);
	}
	fe_mul(t2, t3, t2);
	for (i = 1; i < 51; ++i) {
		fe_square(t2, t2);
	}
	fe_mul(t1, t2, t1);
	for (i = 1; i < 6; ++i) {
		fe_square(t1, t1);
	}
	fe_mul(z, t1, t0);
}

DECLSPEC inline void fe_to_bytes(uint8_t s[32], const FieldElement x)
{
    memcpy(s, x, 32);
}